{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"../models\"\n",
    "image_path = r\"F:\\projects\\ImageAI\\data-images\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "Training with GPU\n",
      "==================================================\n",
      "Epoch 1/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  4.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 16.8996 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 71120032803498.6719 Accuracy: 0.4167\n",
      "Epoch 2/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 15.8721 Accuracy: 0.3235\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 688072444.0000 Accuracy: 0.4167\n",
      "Epoch 3/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 7.7935 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 358954.9310 Accuracy: 0.4167\n",
      "Epoch 4/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 2.4750 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 3636.8224 Accuracy: 0.5833\n",
      "Epoch 5/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 3.1157 Accuracy: 0.4118\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 450.6679 Accuracy: 0.3333\n",
      "Epoch 6/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 2.6678 Accuracy: 0.5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 243.7726 Accuracy: 0.3333\n",
      "Epoch 7/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 2.2962 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6382 Accuracy: 0.5833\n",
      "Epoch 8/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 2.6718 Accuracy: 0.3824\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 3.2020 Accuracy: 0.5000\n",
      "Epoch 9/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 1.7047 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7615 Accuracy: 0.5000\n",
      "Epoch 10/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 1.0887 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 21.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7489 Accuracy: 0.5000\n",
      "Epoch 11/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 1.0387 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7955 Accuracy: 0.5000\n",
      "Epoch 12/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 1.7051 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6334 Accuracy: 0.5833\n",
      "Epoch 13/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 1.7787 Accuracy: 0.4118\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7113 Accuracy: 0.4167\n",
      "Epoch 14/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 1.0196 Accuracy: 0.5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 21.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7387 Accuracy: 0.5833\n",
      "Epoch 15/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.8587 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6956 Accuracy: 0.4167\n",
      "Epoch 16/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.8895 Accuracy: 0.4412\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6974 Accuracy: 0.5000\n",
      "Epoch 17/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.8003 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 21.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7048 Accuracy: 0.5833\n",
      "Epoch 18/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.8540 Accuracy: 0.4412\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7049 Accuracy: 0.5000\n",
      "Epoch 19/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6594 Accuracy: 0.5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7129 Accuracy: 0.5833\n",
      "Epoch 20/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7757 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 21.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6831 Accuracy: 0.5833\n",
      "Epoch 21/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7873 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6807 Accuracy: 0.5833\n",
      "Epoch 22/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7951 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7177 Accuracy: 0.5833\n",
      "Epoch 23/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6488 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6970 Accuracy: 0.5000\n",
      "Epoch 24/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6765 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7144 Accuracy: 0.4167\n",
      "Epoch 25/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7337 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7092 Accuracy: 0.5833\n",
      "Epoch 26/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7300 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6802 Accuracy: 0.5833\n",
      "Epoch 27/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7148 Accuracy: 0.4706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6721 Accuracy: 0.5833\n",
      "Epoch 28/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6904 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6791 Accuracy: 0.5833\n",
      "Epoch 29/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7107 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 22.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6776 Accuracy: 0.5833\n",
      "Epoch 30/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7312 Accuracy: 0.5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6803 Accuracy: 0.5833\n",
      "Epoch 31/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6260 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7073 Accuracy: 0.5833\n",
      "Epoch 32/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7365 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7157 Accuracy: 0.5833\n",
      "Epoch 33/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7192 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6815 Accuracy: 0.5833\n",
      "Epoch 34/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6952 Accuracy: 0.5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 22.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6776 Accuracy: 0.5833\n",
      "Epoch 35/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6806 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 22.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6688 Accuracy: 0.5833\n",
      "Epoch 36/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6488 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6844 Accuracy: 0.5833\n",
      "Epoch 37/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6018 Accuracy: 0.7059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.9153 Accuracy: 0.4167\n",
      "Epoch 38/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.8475 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7768 Accuracy: 0.5833\n",
      "Epoch 39/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7040 Accuracy: 0.6765\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 21.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6940 Accuracy: 0.5833\n",
      "Epoch 40/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7622 Accuracy: 0.4706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6789 Accuracy: 0.5833\n",
      "Epoch 41/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7105 Accuracy: 0.4412\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6942 Accuracy: 0.5000\n",
      "Epoch 42/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7178 Accuracy: 0.4706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6804 Accuracy: 0.5833\n",
      "Epoch 43/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7086 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6816 Accuracy: 0.5833\n",
      "Epoch 44/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6652 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7247 Accuracy: 0.5000\n",
      "Epoch 45/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6330 Accuracy: 0.6765\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.8352 Accuracy: 0.4167\n",
      "Epoch 46/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6751 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.8076 Accuracy: 0.4167\n",
      "Epoch 47/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6763 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7010 Accuracy: 0.5833\n",
      "Epoch 48/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6855 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6846 Accuracy: 0.5833\n",
      "Epoch 49/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7186 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6841 Accuracy: 0.5833\n",
      "Epoch 50/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6517 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6939 Accuracy: 0.5833\n",
      "Epoch 51/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7701 Accuracy: 0.4706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6821 Accuracy: 0.5833\n",
      "Epoch 52/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6746 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6821 Accuracy: 0.5833\n",
      "Epoch 53/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6770 Accuracy: 0.7059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6840 Accuracy: 0.5833\n",
      "Epoch 54/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7063 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7074 Accuracy: 0.5833\n",
      "Epoch 55/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6570 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6840 Accuracy: 0.5833\n",
      "Epoch 56/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7472 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6890 Accuracy: 0.5833\n",
      "Epoch 57/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6800 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7008 Accuracy: 0.5833\n",
      "Epoch 58/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7167 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6827 Accuracy: 0.5833\n",
      "Epoch 59/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6765 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6895 Accuracy: 0.5000\n",
      "Epoch 60/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6425 Accuracy: 0.7059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7032 Accuracy: 0.5833\n",
      "Epoch 61/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6012 Accuracy: 0.7059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7014 Accuracy: 0.5833\n",
      "Epoch 62/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7284 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6950 Accuracy: 0.5833\n",
      "Epoch 63/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6250 Accuracy: 0.6765\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7138 Accuracy: 0.5833\n",
      "Epoch 64/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6674 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6876 Accuracy: 0.5833\n",
      "Epoch 65/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6482 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 21.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6824 Accuracy: 0.5833\n",
      "Epoch 66/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6446 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6965 Accuracy: 0.5833\n",
      "Epoch 67/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6756 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7026 Accuracy: 0.5833\n",
      "Epoch 68/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6591 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6712 Accuracy: 0.5000\n",
      "Epoch 69/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6022 Accuracy: 0.7647\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6920 Accuracy: 0.5833\n",
      "Epoch 70/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7085 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6952 Accuracy: 0.5833\n",
      "Epoch 71/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7464 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6962 Accuracy: 0.5833\n",
      "Epoch 72/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6688 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6668 Accuracy: 0.5833\n",
      "Epoch 73/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6088 Accuracy: 0.6765\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6856 Accuracy: 0.5833\n",
      "Epoch 74/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6754 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6988 Accuracy: 0.5833\n",
      "Epoch 75/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.5496 Accuracy: 0.7647\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6791 Accuracy: 0.6667\n",
      "Epoch 76/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6370 Accuracy: 0.7647\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.8590 Accuracy: 0.5000\n",
      "Epoch 77/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7745 Accuracy: 0.4706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7958 Accuracy: 0.4167\n",
      "Epoch 78/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6115 Accuracy: 0.6765\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7699 Accuracy: 0.5000\n",
      "Epoch 79/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6640 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 1.2025 Accuracy: 0.3333\n",
      "Epoch 80/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.8342 Accuracy: 0.4118\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6966 Accuracy: 0.5833\n",
      "Epoch 81/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7582 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7305 Accuracy: 0.5833\n",
      "Epoch 82/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6930 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6823 Accuracy: 0.5833\n",
      "Epoch 83/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7358 Accuracy: 0.2941\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7105 Accuracy: 0.4167\n",
      "Epoch 84/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7223 Accuracy: 0.4706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6899 Accuracy: 0.5833\n",
      "Epoch 85/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7082 Accuracy: 0.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6795 Accuracy: 0.5833\n",
      "Epoch 86/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6977 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6811 Accuracy: 0.5833\n",
      "Epoch 87/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7084 Accuracy: 0.4412\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6947 Accuracy: 0.4167\n",
      "Epoch 88/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6728 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6828 Accuracy: 0.5833\n",
      "Epoch 89/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6990 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6953 Accuracy: 0.5833\n",
      "Epoch 90/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7046 Accuracy: 0.5294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6805 Accuracy: 0.5833\n",
      "Epoch 91/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6639 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 20.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6792 Accuracy: 0.5833\n",
      "Epoch 92/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6555 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6816 Accuracy: 0.5833\n",
      "Epoch 93/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6940 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6901 Accuracy: 0.5833\n",
      "Epoch 94/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6323 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6834 Accuracy: 0.5833\n",
      "Epoch 95/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7597 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.7016 Accuracy: 0.5833\n",
      "Epoch 96/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.7132 Accuracy: 0.5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 18.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6846 Accuracy: 0.5833\n",
      "Epoch 97/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6494 Accuracy: 0.6471\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6753 Accuracy: 0.5833\n",
      "Epoch 98/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6985 Accuracy: 0.5882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6793 Accuracy: 0.5833\n",
      "Epoch 99/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6737 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 17.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6834 Accuracy: 0.5833\n",
      "Epoch 100/100\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.6814 Accuracy: 0.6176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 19.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.6957 Accuracy: 0.5833\n",
      "Training completed in 2m 22s\n",
      "Best test accuracy: 0.6667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from imageai.Classification.Custom import ClassificationModelTrainer\n",
    "import os\n",
    "\n",
    "dirs = [\n",
    "    \"pets//train//dog//dog-train-images\",\n",
    "    \"pets//train//cat//cat-train-images\",\n",
    "    # \"pets//train//squirrel//squirrel-train-images\",\n",
    "    # \"pets//train//snake//snake-train-images \",\n",
    "    \"pets//test//dog//dog-test-images\",\n",
    "    \"pets//test//cat//cat-test-images\",\n",
    "    # \"pets//test//squirrel//squirrel-test-images\",\n",
    "    # \"pets//test//snake//snake-test-images\",\n",
    "]\n",
    "os.makedirs(\"custom-trained-pets\", exist_ok=True)\n",
    "for dir in dirs:\n",
    "    os.makedirs(dir, exist_ok=True)\n",
    "\n",
    "execution_path = os.getcwd()\n",
    "\n",
    "model_trainer = ClassificationModelTrainer()\n",
    "model_trainer.setModelTypeAsResNet50()\n",
    "model_trainer.setDataDirectory(\"pets\")\n",
    "model_trainer.trainModel(\n",
    "    # num_objects=4,\n",
    "    # num_experiments=100,\n",
    "    num_experiments=100,\n",
    "    # enhance_data=True,\n",
    "    # batch_size=32,\n",
    "    batch_size=5, # only use 10 ~ 20 cat photos and dog photos to train\n",
    "    # show_network_summary=True,\n",
    "    model_directory=\"custom-trained-pets\",\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.jpg is most like a dog, probability is [69.2646], result is True\n",
      "2.jpg is most like a dog, probability is [69.2649], result is True\n",
      "3.jpg is most like a cat, probability is [88.2442], result is True\n",
      "4.jpg is most like a dog, probability is [69.2497], result is False\n",
      "5.jpg is most like a dog, probability is [69.2648], result is False\n",
      "6.jpg is most like a dog, probability is [69.2579], result is True\n",
      "7.jpg is most like a dog, probability is [69.2648], result is True\n"
     ]
    }
   ],
   "source": [
    "from imageai.Classification.Custom import CustomImageClassification\n",
    "import os\n",
    "\n",
    "execution_path = os.getcwd()\n",
    "\n",
    "prediction = CustomImageClassification()\n",
    "prediction.setModelTypeAsResNet50()\n",
    "prediction.setModelPath(\n",
    "    os.path.join(\n",
    "        execution_path, \"custom-trained-pets/resnet50-pets-test_acc_0.66667_epoch-74.pt\"\n",
    "    )\n",
    ")\n",
    "prediction.setJsonPath(\n",
    "    os.path.join(execution_path, \"custom-trained-pets/pets_model_classes.json\")\n",
    ")\n",
    "prediction.loadModel()\n",
    "\n",
    "answers = [\"dog\", \"dog\", \"cat\", \"cat\", \"cat\", \"dog\", \"dog\"]\n",
    "# for 5, it is from test dataset\n",
    "for i in range(7):\n",
    "    predictions, probabilities = prediction.classifyImage(\n",
    "        os.path.join(execution_path, f\"custom-trained-pets/{i+1}.jpg\"),\n",
    "        result_count=2,  # only have two result\n",
    "    )\n",
    "\n",
    "    maxProbability = 0\n",
    "    pred = \"Unknow\"\n",
    "    for eachPrediction, eachProbability in zip(predictions, probabilities):\n",
    "        # print(eachPrediction, \" : \", eachProbability)\n",
    "        if eachProbability > maxProbability:\n",
    "            maxProbability = eachProbability\n",
    "            pred = eachPrediction\n",
    "    print(\n",
    "        f\"{i+1}.jpg is most like a {pred}, probability is [{maxProbability}], result is {answers[i] == pred}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv-gpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
